{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
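    "# JAX/Flax --> ONNX\n",
    "\n",
    "This notebook converts a Brax PPO policy MLP to an ONNX checkpoint: it loads the trained parameters, rebuilds the policy network in TensorFlow/Keras, copies the weights over, and exports the result with tf2onnx."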
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from mujoco_playground.config import locomotion_params, manipulation_params\n",
    "from mujoco_playground import locomotion, manipulation\n",
    "import functools\n",
    "import pickle\n",
    "import jax.numpy as jp\n",
    "import jax\n",
    "import tf2onnx\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "import onnxruntime as rt\n",
    "from brax.training.acme import running_statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"BerkeleyHumanoidJoystickFlatTerrain\"\n",
    "# ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
    "ppo_params = locomotion_params.brax_ppo_config(env_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def identity_observation_preprocessor(observation, preprocessor_params):\n",
    "  \"\"\"Pass-through preprocessor: returns `observation` and ignores the params.\"\"\"\n",
    "  del preprocessor_params  # Unused.\n",
    "  return observation\n",
    "\n",
    "\n",
    "# Only the brax PPO train.py script creates the normalization function when\n",
    "# normalize_observations is True, so it has to be passed in explicitly here.\n",
    "network_factory = functools.partial(\n",
    "  ppo_networks.make_ppo_networks,\n",
    "  **ppo_params.network_factory,\n",
    "  preprocess_observations_fn=running_statistics.normalize,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env_cfg = locomotion.get_default_config(env_name)\n",
    "# env = locomotion.load(env_name, config=env_cfg)\n",
    "env_cfg = locomotion.get_default_config(env_name)\n",
    "env = locomotion.load(env_name, config=env_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_size = env.observation_size\n",
    "act_size = env.action_size\n",
    "print(obs_size, act_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ppo_network = network_factory(obs_size, act_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the training checkpoints. Adjust for your machine.\n",
    "ckpt_root = \"/home/kevin/mujoco_playground/mujoco_playground/experimental/learning/checkpoints\"\n",
    "\n",
    "# Earlier runs, kept for reference only (none of these are loaded):\n",
    "#   G1JoystickFlatTerrain-20250109-015111  # g1_fulljntrange\n",
    "#   G1JoystickFlatTerrain-20250109-021034  # g1_fulljntrange_energy\n",
    "#   G1JoystickFlatTerrain-20250109-022739  # g1_fulljntrange_energy_biggerangvel (bad)\n",
    "#   G1JoystickFlatTerrain-20250109-031036  # g1_fulljntrange_energy_biggerangvel_again (redo)\n",
    "#   G1JoystickFlatTerrain-20250109-040028\n",
    "#   G1JoystickFlatTerrain-20250109-041537\n",
    "#   G1JoystickFlatTerrain-20250109-043559\n",
    "#   G1JoystickFlatTerrain-20250109-045516\n",
    "#   G1JoystickFlatTerrain-20250109-052430\n",
    "#   G1JoystickFlatTerrain-20250109-135615\n",
    "#   G1JoystickFlatTerrain-20250109-152932\n",
    "#   G1JoystickFlatTerrain-20250109-185338\n",
    "#   G1JoystickRoughTerrain-20250109-192738\n",
    "#   G1JoystickFlatTerrain-20250109-232834\n",
    "#   G1JoystickFlatTerrain-20250109-170057\n",
    "#   G1JoystickFlatTerrain-20250109-235213\n",
    "#   BerkeleyHumanoidJoystickFlatTerrain-20250110-231813\n",
    "#   BerkeleyHumanoidJoystickFlatTerrain-20250110-233839\n",
    "\n",
    "# The run that is actually exported.\n",
    "ckpt_name = \"BerkeleyHumanoidJoystickFlatTerrain-20250111-001442\"\n",
    "ckpt_path = os.path.join(ckpt_root, ckpt_name, \"params.pkl\")\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load checkpoints you trust.\n",
    "with open(ckpt_path, 'rb') as f:\n",
    "    params = pickle.load(f)\n",
    "print(params.keys())\n",
    "\n",
    "# File name the exported ONNX model is written to.\n",
    "output_path = \"bh_policy.onnx\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repack the loaded checkpoint into the (normalizer_params, policy_params)\n",
    "# tuple that is passed to make_inference_fn and indexed as params[0] / params[1].\n",
    "# The guard keeps this cell idempotent: once `params` is already the tuple,\n",
    "# re-running the cell would otherwise fail on the string lookups.\n",
    "if not isinstance(params, tuple):\n",
    "    params = (params[\"normalizer_params\"], params[\"policy_params\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "make_inference_fn = ppo_networks.make_inference_fn(ppo_network)\n",
    "inference_fn = make_inference_fn(params, deterministic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(tf.keras.Model):\n",
    "    \"\"\"Keras counterpart of the policy MLP, used as the source for ONNX export.\n",
    "\n",
    "    Inputs are optionally normalized with a fixed mean/std, passed through a\n",
    "    Sequential block named `MLP_0` whose Dense layers are named `hidden_{i}`,\n",
    "    and the first half of the final output is returned after a tanh.\n",
    "\n",
    "    Args:\n",
    "        layer_sizes: Width of each Dense layer, in order.\n",
    "        activation: Activation used by the Dense layers.\n",
    "        kernel_init: Kernel initializer for the Dense layers.\n",
    "        activate_final: If False, the last Dense layer gets no activation.\n",
    "        bias: Whether the Dense layers use a bias term.\n",
    "        layer_norm: If True, a LayerNormalization follows every Dense layer.\n",
    "        mean_std: Optional (mean, std) pair; when given, inputs are normalized\n",
    "            as (inputs - mean) / std before the MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        layer_sizes,\n",
    "        activation=tf.nn.relu,\n",
    "        kernel_init=\"lecun_uniform\",\n",
    "        activate_final=False,\n",
    "        bias=True,\n",
    "        layer_norm=False,\n",
    "        mean_std=None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layer_sizes = layer_sizes\n",
    "        self.activation = activation\n",
    "        self.kernel_init = kernel_init\n",
    "        self.activate_final = activate_final\n",
    "        self.bias = bias\n",
    "        self.layer_norm = layer_norm\n",
    "\n",
    "        # Normalization statistics are stored as non-trainable variables.\n",
    "        if mean_std is not None:\n",
    "            self.mean = tf.Variable(mean_std[0], trainable=False, dtype=tf.float32)\n",
    "            self.std = tf.Variable(mean_std[1], trainable=False, dtype=tf.float32)\n",
    "        else:\n",
    "            self.mean = None\n",
    "            self.std = None\n",
    "\n",
    "        self.mlp_block = tf.keras.Sequential(name=\"MLP_0\")\n",
    "        last_index = len(self.layer_sizes) - 1\n",
    "        for i, size in enumerate(self.layer_sizes):\n",
    "            # Pick the final layer's activation at construction time. Clearing it\n",
    "            # afterwards through mlp_block.layers[-1] misses the Dense layer when\n",
    "            # layer_norm=True, because the last layer is then a LayerNormalization.\n",
    "            is_final = i == last_index\n",
    "            layer_activation = (\n",
    "                None if is_final and not self.activate_final else self.activation\n",
    "            )\n",
    "            dense_layer = layers.Dense(\n",
    "                size,\n",
    "                activation=layer_activation,\n",
    "                kernel_initializer=self.kernel_init,\n",
    "                name=f\"hidden_{i}\",\n",
    "                use_bias=self.bias,\n",
    "            )\n",
    "            self.mlp_block.add(dense_layer)\n",
    "            # NOTE(review): with layer_norm=True a LayerNormalization is also added\n",
    "            # after the final Dense layer -- confirm this matches the JAX network\n",
    "            # being exported.\n",
    "            if self.layer_norm:\n",
    "                self.mlp_block.add(layers.LayerNormalization(name=f\"layer_norm_{i}\"))\n",
    "\n",
    "        self.submodules = [self.mlp_block]\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return the tanh of the first half of the MLP output for `inputs`.\n",
    "\n",
    "        `inputs` may be a tensor or a list whose first element is the tensor.\n",
    "        \"\"\"\n",
    "        if isinstance(inputs, list):\n",
    "            inputs = inputs[0]\n",
    "        if self.mean is not None and self.std is not None:\n",
    "            inputs = (inputs - self.mean) / self.std\n",
    "        logits = self.mlp_block(inputs)\n",
    "        # The output is split into two equal halves along the last axis; only the\n",
    "        # first half (loc) is used, the second half is discarded.\n",
    "        loc, _ = tf.split(logits, 2, axis=-1)\n",
    "        return tf.tanh(loc)\n",
    "\n",
    "def make_policy_network(\n",
    "    param_size,\n",
    "    mean_std,\n",
    "    hidden_layer_sizes=[256, 256],\n",
    "    activation=tf.nn.relu,\n",
    "    kernel_init=\"lecun_uniform\",\n",
    "    layer_norm=False,\n",
    "):\n",
    "    \"\"\"Build the Keras policy MLP.\n",
    "\n",
    "    The network has the given hidden layers followed by one output layer of\n",
    "    width `param_size`; `mean_std` is forwarded for input normalization.\n",
    "    \"\"\"\n",
    "    # Hidden widths first, then the output layer.\n",
    "    all_layer_sizes = [*hidden_layer_sizes, param_size]\n",
    "    return MLP(\n",
    "        layer_sizes=all_layer_sizes,\n",
    "        activation=activation,\n",
    "        kernel_init=kernel_init,\n",
    "        layer_norm=layer_norm,\n",
    "        mean_std=mean_std,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = params[0].mean[\"state\"]\n",
    "std = params[0].std[\"state\"]\n",
    "\n",
    "# Convert mean/std jax arrays to tf tensors.\n",
    "mean_std = (tf.convert_to_tensor(mean), tf.convert_to_tensor(std))\n",
    "\n",
    "tf_policy_network = make_policy_network(\n",
    "    param_size=act_size * 2,\n",
    "    mean_std=mean_std,\n",
    "    hidden_layer_sizes=ppo_params.network_factory.policy_hidden_layer_sizes,\n",
    "    activation=tf.nn.swish,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_input = tf.zeros((1, obs_size[\"state\"][0]))\n",
    "example_output = tf_policy_network(example_input)\n",
    "print(example_output.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "def transfer_weights(jax_params, tf_model):\n",
    "    \"\"\"\n",
    "    Copy Dense-layer weights from a JAX parameter dictionary into the Keras model.\n",
    "\n",
    "    Parameters:\n",
    "    - jax_params: dict\n",
    "      Mapping of layer name to that layer's parameters, for example:\n",
    "      {\n",
    "        'hidden_0': {'kernel': np.ndarray, 'bias': np.ndarray},\n",
    "        'hidden_1': {'kernel': np.ndarray, 'bias': np.ndarray},\n",
    "        'hidden_2': {'kernel': np.ndarray, 'bias': np.ndarray},\n",
    "      }\n",
    "      Each name is looked up inside the model's 'MLP_0' block.\n",
    "\n",
    "    - tf_model: tf.keras.Model\n",
    "      Model containing a sub-layer named 'MLP_0' whose Dense layers carry the\n",
    "      same names as the keys of jax_params.\n",
    "\n",
    "    Layers that are missing from the model, or that are not Dense layers, are\n",
    "    reported and skipped.\n",
    "    \"\"\"\n",
    "    transferred = 0\n",
    "    for layer_name, layer_params in jax_params.items():\n",
    "        try:\n",
    "            tf_layer = tf_model.get_layer(\"MLP_0\").get_layer(name=layer_name)\n",
    "        except ValueError:\n",
    "            print(f\"Layer {layer_name} not found in TensorFlow model.\")\n",
    "            continue\n",
    "        if isinstance(tf_layer, tf.keras.layers.Dense):\n",
    "            kernel = np.array(layer_params['kernel'])\n",
    "            # Only pass a bias when the Dense layer has one; a layer built with\n",
    "            # use_bias=False holds a single weight (the kernel).\n",
    "            if tf_layer.use_bias:\n",
    "                bias = np.array(layer_params['bias'])\n",
    "                print(f\"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, bias shape {bias.shape}\")\n",
    "                tf_layer.set_weights([kernel, bias])\n",
    "            else:\n",
    "                print(f\"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, no bias\")\n",
    "                tf_layer.set_weights([kernel])\n",
    "            transferred += 1\n",
    "        else:\n",
    "            print(f\"Unhandled layer type in {layer_name}: {type(tf_layer)}\")\n",
    "\n",
    "    # Only claim success when every entry of jax_params was actually copied.\n",
    "    if transferred == len(jax_params):\n",
    "        print(\"Weights transferred successfully.\")\n",
    "    else:\n",
    "        print(f\"Transferred {transferred} of {len(jax_params)} layers; see messages above.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer_weights(params[1]['params'], tf_policy_network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example inputs for the model\n",
    "test_input = [np.ones((1, obs_size[\"state\"][0]), dtype=np.float32)]\n",
    "\n",
    "# Define the TensorFlow input signature\n",
    "spec = [tf.TensorSpec(shape=(1, obs_size[\"state\"][0]), dtype=tf.float32, name=\"obs\")]\n",
    "\n",
    "tensorflow_pred = tf_policy_network(test_input)[0]\n",
    "# Build the model by calling it with example data\n",
    "print(f\"Tensorflow prediction: {tensorflow_pred}\")\n",
    "\n",
    "tf_policy_network.output_names = ['continuous_actions']\n",
    "\n",
    "# opset 11 matches isaac lab.\n",
    "model_proto, _ = tf2onnx.convert.from_keras(tf_policy_network, input_signature=spec, opset=11, output_path=output_path)\n",
    "\n",
    "# Run inference with ONNX Runtime\n",
    "output_names = ['continuous_actions']\n",
    "providers = ['CPUExecutionProvider']\n",
    "m = rt.InferenceSession(output_path, providers=providers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_input = {\n",
    "  'obs': np.ones((1, obs_size[\"state\"][0]), dtype=np.float32)\n",
    "}\n",
    "# Prepare inputs for ONNX Runtime\n",
    "onnx_pred = m.run(output_names, onnx_input)[0][0]\n",
    "\n",
    "print(\"ONNX prediction:\", onnx_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_input = {\n",
    "    'state': jp.ones(obs_size[\"state\"]),\n",
    "    'privileged_state': jp.zeros(obs_size[\"privileged_state\"])\n",
    "}\n",
    "jax_pred, _ = inference_fn(test_input, jax.random.PRNGKey(0))\n",
    "print(jax_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "print(onnx_pred.shape)\n",
    "print(tensorflow_pred.shape)\n",
    "print(jax_pred.shape)\n",
    "plt.plot(onnx_pred, label='onnx')\n",
    "plt.plot(tensorflow_pred, label='tensorflow')\n",
    "plt.plot(jax_pred, label='jax')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env_cfg = locomotion.get_default_config(env_name)\n",
    "# env = locomotion.load(env_name, config=env_cfg)\n",
    "# jit_reset = jax.jit(env.reset)\n",
    "# jit_step = jax.jit(env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Test the policy.\n",
    "\n",
    "# # env_cfg = locomotion.get_default_config(env_name)\n",
    "# # env_cfg.init_from_crouch = 0.0\n",
    "# # env = locomotion.load(env_name, config=env_cfg)\n",
    "# # env_cfg = manipulation.get_default_config(env_name)\n",
    "# # env = manipulation.load(env_name, config=env_cfg)\n",
    "# # jit_reset = jax.jit(env.reset)\n",
    "# # jit_step = jax.jit(env.step)\n",
    "\n",
    "# x = 0.8\n",
    "# y = 0.0\n",
    "# yaw = 0.3\n",
    "# command = jp.array([x, y, yaw])\n",
    "# # actions = []\n",
    "\n",
    "# states = [state := jit_reset(jax.random.PRNGKey(555))]\n",
    "# state.info[\"command\"] = command\n",
    "# for _ in range(env_cfg.episode_length):\n",
    "#   onnx_input = {'obs': np.array(state.obs[\"state\"].reshape(1, -1))}\n",
    "#   action = m.run(output_names, onnx_input)[0][0]\n",
    "#   state = jit_step(state, jp.array(action))\n",
    "#   state.info[\"command\"] = command\n",
    "#   states.append(state)\n",
    "#   # actions.append(state.info[\"motor_targets\"])\n",
    "#   # actions.append(action)\n",
    "#   if state.done:\n",
    "#     print(\"Unexpected termination.\")\n",
    "#     break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import mediapy as media\n",
    "# fps = 1.0 / env.dt\n",
    "\n",
    "# frames = env.render(\n",
    "#     states,\n",
    "#     camera=\"track\",\n",
    "#     width=640*2,\n",
    "#     height=480*2,\n",
    "# )\n",
    "# media.show_video(frames, fps=fps, loop=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
